{-- 
    Code for compiling efficient pattern matches for functions with
    multi equation definitions with more than a single pattern.

    This used to be done like 

    >                                case (a,b,c) of
    > f a1 b1 c1 = e1                   (a1,b1,c1) -> e1
    > f a2 b2 c2 = e2                   (a2,b2,c2) -> e2

    and the code generator has facilities to generate code for matching that 
    wouldn't construct/deconstruct the tuple.

    But this approach fails when there are higher rank function arguments, since using
    them as tuple elements would instantiate them, and thus ruin their higher rankedness.

    See also 'caseComp'

    We follow closely 'http://research.microsoft.com/en-us/um/people/simonpj/papers/slpj-book-1987/slpj-book-1987.pdf chapter 5'
    of the famous "Implementation of Functional Programming Languages".

    -}
module frege.compiler.common.PatternCompiler where

import frege.Prelude hiding (<+>)

import Compiler.enums.Flags(TRACE7, STRICTLRPATS)
import Compiler.enums.CaseKind
import Compiler.enums.Literals

import Compiler.classes.Nice

import Compiler.types.Symbols
import Compiler.types.Global
import Compiler.types.Expression
import Compiler.types.Patterns
import Compiler.types.QNames
import Compiler.types.Positions

import Compiler.common.SymbolTable
import Compiler.common.Errors as E()
import Compiler.common.Trans

import Lib.PP(text, <+>, <>, </>)
import Data.List(sortBy, groupBy)
import Compiler.Utilities as U(freshVar)




--- Run the pattern compiler over the expression of a variable symbol, if it has one.
--- The rewritten expression is stored back in the symbol table via 'changeSym'.
--- Any symbol that is not a 'SymV' is reported as a fatal error.
ccSym (vsym@SymV{}) = case vsym.expr of
    Just x  → do
        compiled ← x >>= ccExpr
        changeSym vsym.{expr = Just (return compiled)}
    Nothing → pure ()
ccSym other = do
    g ← getST
    E.fatal other.pos (text ("ccSym no SymV : " ++ other.nice g))



--- Apply 'caseComp' to an expression by way of 'U.mapEx'.
ccExpr ex = U.mapEx true caseComp ex

{-- 
    Canonicalize a pattern for the purpose of the pattern compiler.

    Given

    > case u of pat -> ex

    we want _pat_ to be a variable 'PVar', a literal 'PLit' or a constructor 'PCon'.
    For our purpose, 'PMat' counts as a special form of 'PLit'.

    We can achieve this by transforming like this, where @v@ resembles a new name.

    > p@pat  → ex                               canonicPatEx u pat (ex[p/u])
    > !p     → ex                               canonicPatEx u v | !p ← v          = ex,   if trimmed p is a variable
    >                                           canonicPatEx u p ex,        otherwise
    > (p::t) → ex                               canonicPatEx u v | (p∷t) ← v       = ex,   if trimmed p is a variable
    >                                           canonicPatEx u p ex,        otherwise (+warning)

    The first argument *must be* a 'Vbl'!

    The result is the canonical pattern together with the accordingly
    adjusted right hand side.

    Patterns with field syntax ('PConFS') are not expected here and
    abort with an error.
-}
canonicPatEx :: Expr → Pattern → Expr → StG (Pattern, Expr)
canonicPatEx u complex ex = 
    let -- construct | pat ← x = ex  or case x of { pat → ex }
        p |<- x = Case{ckind=CWhen, ex=x, alts=[CAlt{pat=p, ex}], typ=Nothing}
        -- Replace the complex pattern with a fresh variable and re-check
        -- the original pattern against that variable in a guard.
        -- Shared by the strictness and the annotation case below.
        viaGuard = do
            vp ← freshVar (getrange complex)
            canonicPatEx u vp (complex |<- vbl vp)
    in case complex  of
        -- already canonical
        PVar{} →  pure (complex, ex)
        PCon{} →  pure (complex, ex)
        PLit{} →  pure (complex, ex)
        PMat{} →  pure (complex, ex)
        -- p@pat: make the alias refer to u in the right hand side, continue with pat
        PAt{uid, pat} →
            changeSID uid u.name ex >>= canonicPatEx u pat
        PUser{pat}
            | PVar{} ← trimmed pat →  viaGuard
            | otherwise = canonicPatEx u pat ex             -- ! and ? are immaterial
        PAnn{pat}
            | PVar{} ← trimmed pat →  viaGuard
            | otherwise = do
                E.warn (getrange complex) (
                    text "When defining a function with multiple equations,"
                    </> text "annotations are only allowed for variables."
                    </> text "I'm ignoring the annotation in the hope it will do no harm."
                    </> text "Please write a function signature or use a single equation"
                    </> text "if this item really needs an annotation."
                    )
                canonicPatEx u pat ex
        -- the message must name the offending constructor: PCon itself is fine (see above)
        PConFS{} →  error "canonicPatEx: PConFS not allowed here"
        
        

--- Canonicalize the leading pattern of one equation with respect to the scrutinee _u_.
--- An equation without patterns is returned unchanged.
mkCanonic ∷ ExprT → ([Pattern],ExprT) → StG ([Pattern],ExprT)
mkCanonic u row = case row of
    (p:ps, ex) → do
        (canon, body) ← canonicPatEx u p ex
        pure (canon:ps, body)
    _          → pure row

{--
    Look for case expressions of the shape

    > case (u1, ...., un) of
    >    (a1, ..., an) -> e1
    >    (b1, ..., bn) -> e1

    and get rid of the construction and subsequent deconstruction of the tuple.

    It is important that this happens before type checking. Consider:

    > foo :: (forall a b.[a] → [b]) -> [c] -> [d] -> ([e],[f])
    > foo f [] ys = ([], f ys)
    > foo f xs [] = (f xs, [])      -- error: xs :: [c], expected [e]

    The @Fix@ pass stuffs all arguments of an equation in a tuple, which yields
    a structure like the one above. This, however, instantiates the higher rank 
    function, hence the error. 
    See also 'https://github.com/Frege/frege/issues/273 Issue 273'

    The rewrite applies to every case expression that scrutinizes a product
    constructor (not only tuples), provided the constructor is applied to 
    local variables only and every alternative is a constructor pattern
    with the matching number of sub-patterns.

    The result is always a 'Left' value: the compiled expression if the
    rewrite applied, the unchanged expression otherwise.
-}
caseComp expr = do
    g ← getST
    case expr of
        Case{ckind=CNormal, ex, alts}
            | App{} ← ex,
              (Con{name} : args) ← map fst (flatx ex),
              productCon name g,
              all isLocal args,
              all (hasArity (length args)) alts     -- do default casing later
            = do
                E.logmsg TRACE7 (getpos expr) (text "caseComp: found one " <+> nicec g expr)
                compiled ← match (getpos expr) CNormal args (map unpack alts)
                E.logmsg TRACE7 (getpos expr) (text "caseComp: after " <+> nicec g compiled)
                pure (Left compiled)
            | otherwise = pure (Left expr)
            where
                -- only local variables qualify as constructor arguments
                isLocal Vbl{name=Local{}} = true
                isLocal _                 = false
                -- the alternative must be a constructor pattern with n sub-patterns
                hasArity n CAlt{pat=PCon{pats}} = length pats == n
                hasArity _ _                    = false
                -- sub-patterns and right hand side; the pattern must be a PCon!
                unpack CAlt{pat, ex}   = (pat.pats, ex)
        _ → pure (Left expr)



--- Make a 'PVar' pattern from a 'Vbl', re-using its position, uid and base name.
pvar :: Expr → Pattern
pvar tv = PVar{pos, uid, var}
    where
        pos = tv.pos
        uid = tv.name.uid
        var = tv.name.base

--- Make an untyped 'Vbl' with a 'Local' name from a 'PVar' pattern.
vbl :: Pattern → Expr
vbl pv = Vbl{pos=pv.pos, name, typ=Nothing}
    where
        name = Local{uid=pv.uid, base=pv.var}

--- Pretty print an expression, showing case expressions in full:
--- the keyword ("when" for 'CWhen', "case" otherwise), the scrutinee and,
--- in braces, every alternative with its recursively printed right hand side.
--- All other expressions are printed with 'nicet'.
nicec g Case{ckind, ex, alts}
     = text keyword <+> nicet g ex <+> PP.bracket "{" (PP.stack (map showAlt alts)) "}" 
        where
            keyword = if ckind == CWhen then "when" else "case"
            showAlt CAlt{pat, ex=rhs} = nicet g pat <+> text "→" <+> nicec g rhs <> text ";"
nicec g x = nicet g x


{--
    Compile the complex case.

    Turns a list of equations, each given by a list of patterns and a right
    hand side, into nested case expressions on the scrutinized expressions.

    Arguments:

    - the position used for trace messages and for the fatal error on bad input
    - the kind of case to build: for 'CNormal' a single case with all alternatives
      is made, otherwise the alternatives are chained as nested 'CWhen' cases 
    - the scrutinized expressions, one for each pattern column
    - the equations
-}
match ∷ Position → CKind → [Expr] → [([Pattern],Expr)] → StG Expr
match pos ck us ys = do
        g ← getST
        -- logit pos us ys
        case us of
            -- nothing left to scrutinize: the equations must have run out of patterns, too
            [] → case ys of 
                -- []          → pure e
                [([], ex)]  → do
                    E.logmsg TRACE7 (getpos ex) (nicec g ex)
                    pure ex
                -- several equations left: chain the later right hand sides 
                -- onto the first one by means of 'add'
                ([], ex):exs  →
                    foldM (flip add) ex (map snd exs) 
                    --do
                    --forM_ exs $ \(_,ux) -> 
                    --    E.warn (getpos ux) (text "Alternative can never be reached.")
                    --E.logmsg TRACE7 (getpos ex) (nicec g ex)
                    --pure ex
                _ -> E.fatal pos (text "bad match parameters, turn on -x7")
            (u:xs) -> do
                -- make sure the first pattern of every equation is canonical
                ys' ← mapM (mkCanonic u) ys
                logit pos us ys'
                case () of
                        -- variable rule: substitute u for the variable in each right hand side
                        -- and go on with the remaining columns
                     () | allvars ys' = do
                            ys'' ← mapM (varRule u) ys'
                            match pos ck xs ys'' 
                        -- constructor rule: sort and group the equations by constructor name,
                        -- each group gives one case alternative
                        | allcons ys' = do
                            -- e.g. [[([Cons a b,...], e1), ...], [([Nil, ...], e2]]
                            let gys = groupBy (using (_.qname . head . fst)) (sorted ys')
                            alts ← mapM (mkCalt xs) gys
                            mkCase u alts
                        -- boolean literals: sort and group by literal value
                        | allbool ys' = do
                            let gys = groupBy (using (getLit . head . fst)) (sortlit ys')
                            alts ← mapM (mkBool xs) gys
                            mkCase u alts
                        -- other literals: one alternative per equation, no sorting or grouping
                        | alllits ys' = do
                            -- let gys = groupBy (using (getLit . head . fst)) (sortlit ys')
                            alts ← mapM (mkLits xs) ys'
                            mkCase u alts
                        {- 
                            Here come the non-uniform cases.
                            Non-uniform means we have a bunch of constructors or literals
                            and at least one variable (or underscore, which is just an anonymous variable).
                            This tends to generate more complex code, and what is worse, 
                            the generated code may trigger unjustified "incomplete pattern" warnings.
                            
                            For example:
                            
                                general Nothing  []    = 42
                                general _        (_:_) = 43
                                general (Just _) []    = 44
                            
                            will generate:
                            
                                general a b = case a of 
                                    Nothing | []  ← b   = 42
                                    _   | _:_ ← b   = 43
                                        | otherwise = case a of
                                                        Just _ → case b of [] -> 44
                            
                            which works for the most general case. However, in the process the 
                            knowledge that the first argument cannot be Nothing anymore in the
                            3rd equation has been lost.
                            Hence you get a warning on the 3rd equation that Nothing is missing.
                            And in addition you get a warning that (_:_) is missing, though this
                            would have been caught by the 2nd equation, too.
                            
                            However, we can generate good code that doesn't trigger unjustified 
                            "incomplete pattern" warnings in several cases nonetheless:
                            
                            1. when the constructor reveals that we have a product type. Then we can 
                               replace all the variables with that same constructor with all the 
                               subpatterns as underscores, i.e. in
                               
                                  tuple (1,x) (_:_) = x
                                  tuple _     _     = 42
                               the second clause becomes:
                                  tuple (_,_) _     = 42
                               and then this can be handled by the constructor rule.
                               
                            2. We cannot reorder clauses, since then the result of the program may
                               be different. But we can relax the left to right rule, since it will
                               invariably find the same right hand side (if any) as strict left to
                               right matching. However, what could change is the strictness of the 
                               function. Take for example the function "general" from above.
                               With strict left to right matching, it would be strict in the first 
                               argument since we need to check whether it is Nothing. That is, the
                               application
                                    general undefined [1,2,3]
                               would be undefined. However, consider the almost equivalent function
                               
                                   lareneg []    Nothing  = 42
                                   lareneg (_:_) _        = 43
                                   lareneg []    (Just _) = 44
                            
                               Now we can reorder the equations (since the first patterns are all refutable):
                               
                                   lareneg []    Nothing  = 42
                                   lareneg []    (Just _) = 44
                                   lareneg (_:_) _        = 43
                            
                               It becomes quite clear that the patterns are indeed complete. Anyway, if we
                               let
                               
                                    general' a b = lareneg b a
                               
                               we see that the application
                                    general' undefined [1,2,3]
                               is now 43, since the second argument will be checked first and then there
                               is no need to evaluate the first argument.
                               
                               Because we can generate much better code and avoid incomplete 
                               pattern warnings for function "lareneg" we will check the patterns in a 
                               further right column first, when that column is uniform but the 
                               first column is not, UNLESS the compiler option "-strict-pats" 
                               is in effect. 
                               We'll also give a hint about it, so the user can 
                                - leave things as they are, not caring about lazy/strict semantics
                                - leave things as they are but insist on Haskell 2010 semantics
                                  by using -strict-pats
                                - improve the definition. In fact, such a situation as in "general"
                                  is seldom unavoidable and may indicate sloppy or confused thinking. 
                                   
                        -}
                        -- case 1 from above: only variables and product constructors in the first column
                        | all (productConOrVar . head . fst) ys' = do
                            let con = head [ x | y ← ys', let x = head (fst y), singleCon x]
                                -- make [ v, ...] = ex   into [Con _ _ _, ...] = ex'
                                -- where v is substituted with u in ex'
                                v2con (alt@(PVar{pos, uid, var}:_, ex)) = do
                                    c ← mkPcon con
                                    (ps, ex') ← varRule u alt
                                    pure (c:ps, ex')
                                v2con sonst = pure sonst
                            ys'' ← mapM v2con ys'
                            -- same scrutinees again, the first column consists of constructors only now
                            match pos ck us ys''
                        -- case 2 from above: take the first uniform column i and match it first
                        | Flags.isOff g.options.flags STRICTLRPATS,
                          (i:_) ← [ i | i ← [0..length us-1], uniform i ys' ] = do
                            -- the pattern in column i of the first equation
                            let toppat = head . pick i . fst . head $ ys'
                            unless (isPVar toppat) do 
                                E.hint (getpos toppat ) (
                                    text "The column of patterns whose top is "
                                    <+> nicest g toppat 
                                    </> text "will be matched out of order for efficiency reasons."
                                    </> text "If you insist on Haskell semantics, use compiler flag -strict-pats")
                            -- bring scrutinee i and pattern column i to the front 
                            match pos ck (pick i us) [ (pick i ps,x) | (ps,x) ← ys' ]
                        -- general case: split the equations into the leading ones whose first pattern
                        -- is of the same kind as that of the very first equation ("diese", these) 
                        -- and the rest ("jene", those). 
                        -- The code for the rest is attached as fallback to the code for the leading ones.
                        | otherwise =  do
                            let diese = takeWhile (func ys' . head . fst) ys'
                                jene  = drop (length diese) ys'
                            e1 ← match pos CWhen us diese
                            e2 ← match pos ck    us jene
                            add e2 e1 
              where
                -- do all patterns in column a satisfy one and the same of the predicates?
                uniform ∷ Int → [([Pattern],ExprT)] → Bool
                uniform a ys = or (map is [isBool, isPVar, isPCon, isPLit])
                    where
                        is f = all f ps
                        ps   = map ((!!a) . fst) ys
                -- move the element at index i to the front, the others keep their order
                pick ∷ Int → [a] → [a]
                pick i xs = case drop i xs of
                    (y:ys) → y:(take i xs ++ ys)
                    []     → error ("pick " ++ show i ++ " " ++ show (zipWith const [0..] xs))
                -- is this a product constructor pattern?
                singleCon PCon{qname} = productCon qname g
                singleCon _           = false
                -- is this an irrefutable constructor or a variable?
                productConOrVar PVar{}  = true
                productConOrVar con     = singleCon con
                -- build the case on u from the alternatives, according to the case kind ck
                mkCase u alts = case ck of 
                    CNormal → do
                        let ex = Case{ckind=CNormal, ex=u, alts, typ=Nothing} 
                        E.logmsg TRACE7 (getpos ex) (text "mkCase NORMAL: " <+> nicec g ex)
                        pure ex
                    _ → do
                            forM_ alts (\alt →
                                E.logmsg TRACE7 (getpos alt.ex) (
                                    text "mkCase ALT: " <+> nicet g alt.pat <+> text " → "
                                    <+> nicec g alt.ex
                                    ))
                            ex ← nest alts
                            E.logmsg TRACE7 (getpos ex) (text "mkCase WHEN: " <+> nicec g ex)
                            pure ex 
                        where
                                -- chain the alternatives: every CWhen case has the alternative proper
                                -- and a second one with a fresh variable pattern 
                                -- that leads to the case for the remaining alternatives
                                nest [alt] = pure (Case {ckind=CWhen, ex=u, alts=[alt], typ=Nothing})
                                nest (alt:alts) = do
                                    p ← freshVar (getpos (head alts).pat)
                                    e ← nest alts
                                    pure (Case{ckind=CWhen, ex=u, alts=[alt, CAlt p e], typ=Nothing})
                                nest [] = error "nest: []"
    where
        -- classification of canonical patterns, note that PMat counts as literal
        isPVar PVar{} = true
        isPVar _      = false
        isPCon PCon{} = true
        isPCon _      = false
        isPLit PLit{} = true
        isPLit PMat{} = true
        isPLit _      = false
        isBool PLit{kind=LBool} = true
        isBool _                = false
        -- the value of a literal pattern, the empty string for any other pattern
        getLit PLit{value} = value
        getLit PMat{value} = value
        getLit _ = ""
        -- the following check the first pattern of every equation
        allvars ys = all (isPVar . head . fst) ys
        allcons ys = all (isPCon . head . fst) ys
        allbool ys = all (isBool . head . fst) ys
        alllits ys = all (isPLit . head . fst) ys
        -- sort equations by constructor name or literal value of the first pattern
        sorted ys  = sortBy (comparing (_.qname . head . fst)) ys
        sortlit ys = sortBy (comparing (getLit  . head . fst)) ys 
        -- drop the leading variable pattern and make the right hand side refer to u instead
        varRule u (pv:ps, ex) = do 
            nex ← Trans.changeSID pv.uid u.name ex
            pure (ps, nex)
        varRule _ _ = error "empty varRule?"
        -- replace the leading constructor pattern with its sub-patterns
        conRule (pcon:ps, ex) = (pcon.pats ++ ps, ex)
        conRule _ = error "empty conRule?"
        -- drop the leading literal pattern
        litRule (_:ps, ex) = (ps, ex)
        litRule _ = error "empty litRule?"
        -- a copy of the constructor pattern where every sub-pattern is a fresh variable
        mkPcon pcon = do
            pats ← replicateM (length pcon.pats) (freshVar (getpos pcon))
            pure pcon.{pats}
        -- Make the alternative for a group of equations that start with the same constructor. 
        -- The fresh variables for the constructor fields become additional scrutinees
        -- in front of the remaining ones.
        mkCalt ∷ [ExprT] → [([Pattern],ExprT)] → StG CAltT
        mkCalt us alts = do 
            pcon ← mkPcon (head . fst . head $ alts)
            let us' = map vbl pcon.pats ++ us
                ys' = map conRule alts
            subex ← match pos ck us' ys'
            pure CAlt{pat=pcon, ex=subex}
        -- Make the alternative for a group of equations that start with the same literal.
        -- The pattern of the alternative is the literal of the first equation in the group.
        mkBool ∷ [ExprT] → [([Pattern],ExprT)] → StG CAltT
        mkBool us alts = do 
            let tf    = head . fst . head $ alts
                alts' = map litRule alts
            subex ← match pos ck us alts'
            pure CAlt{pat=tf, ex=subex}
        -- make the alternative for a single equation that starts with a literal
        mkLits us alt = do
            let alt' = litRule alt
            subex ← match pos ck us [alt']
            pure CAlt{pat=(head . fst) alt, ex=subex}
        -- the predicate that recognizes the kind of the first pattern of the first equation
        func ys = case ys of
                (PVar{}:_, _):_ → isPVar
                (PCon{}:_, _):_ → isPCon
                _               → isPLit
        -- add e2 e1: make e2 the fallback of the 'CWhen' case e1.
        -- A case with a single alternative gets a second one with a fresh variable 
        -- pattern and e2 as right hand side.
        -- With two alternatives, e2 is added to the expression of the second one.
        -- For anything else a warning is issued and e1 is returned as is.
        add ∷ ExprT → ExprT → ExprD Global
        add e2 (e1@Case{ckind=CWhen, alts=[alt]}) = do
            p ← freshVar (getpos e2)
            pure e1.{alts = [alt, CAlt p e2]}
        add e2 (e1@Case{ckind=CWhen, alts=[alt1,alt2]}) = do
            ex ← add e2 alt2.ex
            pure e1.{alts=[alt1, alt2.{ex}]}
        add e2 e1 = do
            g ← getST
            E.warn (getrange e2) (text "(some) equation(s) can never be reached:" 
                                    </> nicest g e2)
            pure e1
        -- trace the scrutinees and the equations with their right hand sides
        logit ∷ Position → [ExprT] → [([Pattern],ExprT)] → StG ()
        logit pos xs ys = do
            g ← getST
            E.logmsg TRACE7 pos (
                    text "MATCH " <+> text (if ck==CWhen then "when" else "case") 
                                  <+> text "[" <> PP.sep "," (map (nicet g) xs)  <> text "]"
                    </> PP.stack (
                        map (\(ps,x) -> text "[" <> PP.sep "," (map (nicet g) ps)  <> text "] →"
                                        <+> nicet g x)
                            ys
                    ) 
                )
